{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f088c2a",
   "metadata": {},
   "source": [
    "# Preprocessing CODE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "802e3626",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "notebook_dir = os.getcwd()\n",
    "project_root_path = os.path.dirname(notebook_dir)\n",
    "sys.path.insert(0, project_root_path)\n",
    "\n",
    "from src.preprocessing.Derm7pt.concept_preprocessing import encode_image_concepts\n",
    "from src.preprocessing.Derm7pt.label_encoding import one_hot_encode_labels\n",
    "from src.preprocessing.Derm7pt.dataset_utils import filter_concepts_labels\n",
    "from src.preprocessing.Derm7pt.split_data import split_data_by_indices\n",
    "\n",
    "from src import ImageConceptDataset\n",
    "from src.preprocessing import *\n",
    "from src.utils import get_paths, load_Derm_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7ebb376a",
   "metadata": {},
   "outputs": [],
   "source": [
    "paths = get_paths()\n",
    "dataset_handler = load_Derm_dataset(paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "36959f53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure text files exist\n",
    "if not os.path.exists(paths['labels_file']):\n",
    "    export_image_props_to_text(dataset_handler.df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "19f333b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "verbose = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8e3e0616",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = dataset_handler.df\n",
    "# all_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "41d9e9bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of label columns: 5\n",
      "Found 1011 instances.\n",
      "Created matrix of shape: (1011, 5)\n",
      "Total number of concept columns: 19\n",
      "Found 2013 images.\n",
      "Processing in 63 batches of size 32 (for progress reporting)...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing batches: 100%|██████████| 63/63 [00:17<00:00,  3.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished processing.\n",
      "Successfully transformed: 2013 images.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Get labels and concepts\n",
    "image_labels = one_hot_encode_labels(dataset_handler, paths['mapping_file'], verbose=verbose)\n",
    "concepts_matrix = encode_image_concepts(dataset_handler, paths['mapping_file'], verbose=verbose)\n",
    "\n",
    "# Load and transform images\n",
    "image_tensors, image_paths = load_and_transform_images(paths['dir_images'], paths['mapping_file'], resol=299, use_training_transforms=True, batch_size=32, verbose=verbose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "cab5f4eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Labels shape: (2013, 5)\n",
      "Concepts shape: (2013, 19)\n",
      "Image tensors length: 2013\n"
     ]
    }
   ],
   "source": [
    "# Filter if needed\n",
    "if image_labels.shape[0] != len(image_tensors):\n",
    "    filtered_image_labels, filtered_concepts_matrix = filter_concepts_labels(\n",
    "        paths['mapping_file'], image_tensors, image_paths, image_labels, concepts_matrix\n",
    "    )\n",
    "else:\n",
    "    filtered_image_labels, filtered_concepts_matrix = image_labels, concepts_matrix\n",
    "\n",
    "if verbose:\n",
    "    print(\"Labels shape:\", filtered_image_labels.shape)\n",
    "    print(\"Concepts shape:\", filtered_concepts_matrix.shape)\n",
    "    print(\"Image tensors length:\", len(image_tensors))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fd1462a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 0., 0., 0., 0.],\n",
       "       [1., 0., 0., 0., 0.],\n",
       "       [1., 0., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0.],\n",
       "       [0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "filtered_image_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "961ea173",
   "metadata": {},
   "outputs": [],
   "source": [
    "tensors_dict, concepts_dict, labels_dict = split_data_by_indices(\n",
    "    image_tensors, image_paths, filtered_concepts_matrix, filtered_image_labels,\n",
    "    paths, verbose=verbose\n",
    ")\n",
    "\n",
    "train_concept_labels = concepts_dict['train']\n",
    "val_concept_labels = concepts_dict['val']\n",
    "test_concept_labels = concepts_dict['test']\n",
    "\n",
    "train_img_labels = labels_dict['train']\n",
    "val_img_labels = labels_dict['val']\n",
    "test_img_labels = labels_dict['test']\n",
    "\n",
    "train_tensors = tensors_dict['train']\n",
    "val_tensors = tensors_dict['val']\n",
    "test_tensors = tensors_dict['test']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a9ead3a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# concept processing\n",
    "from src.config import DERM7PT_CONFIG\n",
    "\n",
    "class_level_concepts = compute_class_level_concepts(train_concept_labels, None, train_img_labels)\n",
    "\n",
    "# apply class-level concepts to each instance\n",
    "if True:\n",
    "    train_concept_labels, val_concept_labels, test_concept_labels = apply_class_concepts_to_instances(\n",
    "        class_level_concepts, DERM7PT_CONFIG, train_img_labels, train_concept_labels,\n",
    "        test_img_labels, test_concept_labels, val_img_labels, val_concept_labels)\n",
    "\n",
    "common_concept_indices = select_common_concepts(class_level_concepts, min_class_count=0, CUB=False)\n",
    "train_concept_labels = train_concept_labels[:, common_concept_indices]\n",
    "val_concept_labels = val_concept_labels[:, common_concept_indices]\n",
    "test_concept_labels = test_concept_labels[:, common_concept_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "13c20eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.config import PROJECT_ROOT\n",
    "\n",
    "np.save(os.path.join(PROJECT_ROOT, 'output', 'Derm7pt', 'class_level_concepts.npy'), class_level_concepts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "315dabad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(19,)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "common_concept_indices.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "9f5acdf7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset initialized with 413 pre-sorted items.\n",
      "Dataset initialized with 395 pre-sorted items.\n"
     ]
    }
   ],
   "source": [
    "# CREATE TRAIN AND TEST DATASET\n",
    "train_dataset = ImageConceptDataset(\n",
    "    image_tensors=train_tensors,\n",
    "    concept_labels=train_concept_labels,\n",
    "    image_labels=train_img_labels\n",
    ")\n",
    "\n",
    "test_dataset = ImageConceptDataset(\n",
    "    image_tensors=test_tensors,\n",
    "    concept_labels=test_concept_labels,\n",
    "    image_labels=test_img_labels\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "7b624721",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CREATE DATALOADERS FROM DATASETS\n",
    "batch_size = 64\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=False)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aaa468a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3.11.9",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
